def get_hyparams_config_class(dataset_name):
    """Return the hyperparameter config class registered under ``dataset_name``.

    The lookup is done against this module's global namespace, so
    ``dataset_name`` must exactly match the name of a class available at
    module level, e.g. ``"Jango"`` or ``"Mihili_RT"``.

    Args:
        dataset_name: Name of the config class to look up.

    Returns:
        The class object itself (not an instance); call it to build a config.

    Raises:
        NotImplementedError: If no module-level class with that name exists.
    """
    candidate = globals().get(dataset_name)
    # Only classes qualify. Without this check, any other global (this
    # function, module dunders such as ``__name__``, ...) would be handed
    # back to the caller as if it were a config class.
    if not isinstance(candidate, type):
        raise NotImplementedError("Dataset not found: {}".format(dataset_name))
    return candidate

class Jango:
    """Hyperparameter set for the ``Jango`` dataset.

    Every setting is a plain instance attribute assigned in ``__init__``.
    List-valued settings are built fresh for each instance, so mutating one
    config object never leaks into another.
    """

    def __init__(self):
        super().__init__()
        # Step-count settings.
        self.test_per_step: int = 5
        self.fine_tune_per_step: int = 150
        self.training_steps: int = 2000
        self.cursor_training_steps: int = 1000
        # Scalar training / model hyperparameters; their exact meaning is
        # defined by the code that consumes this config.
        self.drop_prob: float = 0.01
        self.learning_rate: float = 0.002
        self.coeff: int = 1
        self.h_dim: int = 35
        self.low_dim: int = 50
        self.latent_dim: int = 32
        self.dense_dim: int = 1000
        self.lstm_layer: int = 1
        self.weight_decay: float = 1e-7

        # Seed settings. rs_list has the same number of entries as
        # random_seed (presumably paired by index -- confirm with the caller).
        self.random_seed: list = [0, 1]
        self.rs_list: list = [[0, 1], [0, 1, 2]]

class Spike:
    """Hyperparameter set for the ``Spike`` dataset.

    Every setting is a plain instance attribute assigned in ``__init__``.
    List-valued settings are built fresh for each instance, so mutating one
    config object never leaks into another.
    """

    def __init__(self):
        super().__init__()
        # Step-count settings.
        self.test_per_step: int = 5
        self.fine_tune_per_step: int = 100
        self.training_steps: int = 1200
        self.cursor_training_steps: int = 1000
        # Scalar training / model hyperparameters; their exact meaning is
        # defined by the code that consumes this config.
        self.drop_prob: float = 0.01
        self.learning_rate: float = 0.002
        self.coeff: int = 1
        self.h_dim: int = 35
        self.low_dim: int = 50
        self.latent_dim: int = 30
        self.dense_dim: int = 1000
        self.lstm_layer: int = 1
        self.weight_decay: float = 1e-7

        # Seed settings. rs_list has the same number of entries as
        # random_seed (presumably paired by index -- confirm with the caller).
        self.random_seed: list = [1, 2]
        self.rs_list: list = [[0, 1], [0, 1, 2]]

class Chewie:
    """Hyperparameter set for the ``Chewie`` dataset.

    Every setting is a plain instance attribute assigned in ``__init__``.
    List-valued settings are built fresh for each instance, so mutating one
    config object never leaks into another.
    """

    def __init__(self):
        super().__init__()
        # Step-count settings.
        self.test_per_step: int = 10
        self.fine_tune_per_step: int = 120
        self.training_steps: int = 2500
        self.cursor_training_steps: int = 1000
        # Scalar training / model hyperparameters; their exact meaning is
        # defined by the code that consumes this config.
        self.drop_prob: float = 0.01
        self.learning_rate: float = 0.002
        self.coeff: int = 1
        self.h_dim: int = 32
        self.low_dim: int = 50
        self.latent_dim: int = 64
        self.dense_dim: int = 1000
        self.lstm_layer: int = 1
        self.weight_decay: float = 1e-5

        # Seed settings. rs_list has the same number of entries as
        # random_seed (presumably paired by index -- confirm with the caller).
        # The inner lists are separate literals on purpose: they must not
        # alias each other the way ``[[0]] * 5`` would.
        self.random_seed: list = [1, 2, 3, 4, 5]
        self.rs_list: list = [[0], [0], [0], [0], [0]]

class Mihili:
    """Hyperparameter set for the ``Mihili`` dataset.

    Every setting is a plain instance attribute assigned in ``__init__``.
    List-valued settings are built fresh for each instance, so mutating one
    config object never leaks into another.
    """

    def __init__(self):
        super().__init__()
        # Step-count settings.
        self.test_per_step: int = 10
        self.fine_tune_per_step: int = 100
        self.training_steps: int = 2500
        self.cursor_training_steps: int = 2000
        # Scalar training / model hyperparameters; their exact meaning is
        # defined by the code that consumes this config.
        self.drop_prob: float = 0.01
        self.learning_rate: float = 0.002
        self.coeff: int = 1
        self.h_dim: int = 32
        self.low_dim: int = 60
        self.latent_dim: int = 64
        self.dense_dim: int = 1000
        self.lstm_layer: int = 1
        self.weight_decay: float = 5e-7

        # Seed settings. rs_list has the same number of entries as
        # random_seed (presumably paired by index -- confirm with the caller).
        # The inner lists are separate literals so they do not alias.
        self.random_seed: list = [1, 3, 5, 7, 9]
        self.rs_list: list = [[0], [0], [0], [0], [0]]

class Mihili_RT:
    """Hyperparameter set for the ``Mihili_RT`` dataset.

    Every setting is a plain instance attribute assigned in ``__init__``.
    List-valued settings are built fresh for each instance, so mutating one
    config object never leaks into another.
    """

    def __init__(self):
        super().__init__()
        # Step-count settings.
        self.test_per_step: int = 10
        self.fine_tune_per_step: int = 200
        self.training_steps: int = 3000
        self.cursor_training_steps: int = 2000
        # Scalar training / model hyperparameters; their exact meaning is
        # defined by the code that consumes this config.
        self.drop_prob: float = 0.01
        self.learning_rate: float = 0.002
        self.coeff: int = 1
        self.h_dim: int = 32
        self.low_dim: int = 60
        self.latent_dim: int = 64
        self.dense_dim: int = 1000
        self.lstm_layer: int = 1
        self.weight_decay: float = 5e-7

        # Seed settings. rs_list has the same number of entries as
        # random_seed (presumably paired by index -- confirm with the caller).
        # Alternative seed list, currently unused: [1, 2, 3, 4, 5].
        self.random_seed: list = [0, 6, 7, 8, 9]
        self.rs_list: list = [[0], [0], [0], [0], [0]]